from model_service.pytorch_model_service import PTServingBaseService
from data import *
import cv2
from PIL import Image
import numpy as np
import torch
import os
from models import *


class garbage_classify_service(PTServingBaseService):
# class garbage_classify_service(object):
    def __init__(self, model_name, model_path):
        super(garbage_classify_service, self).__init__(model_name, model_path)
        # these three parameters are no need to modify
        self.model_name = model_name
        self.model_path = model_path
        self.signature_key = 'predict_images'
        self.dir_path = os.path.dirname(os.path.realpath(model_path))
        self.input_size = 224  # the input image size of the model

        self.input_key_1 = 'input_img'
        self.output_key_1 = 'output_score'

        self.transform = BaseTransform(self.input_size,
                                         rgb_means=(138.11617731, 128.38959552, 116.94768342),
                                         rgb_std=(52.90101662, 54.29838, 56.22659914))
        self.model = load_model_for_test(model_path, 43)
        self.label_id_name_dict = \
        {
            "0": "其他垃圾/一次性快餐盒",
            "1": "其他垃圾/污损塑料",
            "2": "其他垃圾/烟蒂",
            "3": "其他垃圾/牙签",
            "4": "其他垃圾/破碎花盆及碟碗",
            "5": "其他垃圾/竹筷",
            "6": "厨余垃圾/剩饭剩菜",
            "7": "厨余垃圾/大骨头",
            "8": "厨余垃圾/水果果皮",
            "9": "厨余垃圾/水果果肉",
            "10": "厨余垃圾/茶叶渣",
            "11": "厨余垃圾/菜叶菜根",
            "12": "厨余垃圾/蛋壳",
            "13": "厨余垃圾/鱼骨",
            "14": "可回收物/充电宝",
            "15": "可回收物/包",
            "16": "可回收物/化妆品瓶",
            "17": "可回收物/塑料玩具",
            "18": "可回收物/塑料碗盆",
            "19": "可回收物/塑料衣架",
            "20": "可回收物/快递纸袋",
            "21": "可回收物/插头电线",
            "22": "可回收物/旧衣服",
            "23": "可回收物/易拉罐",
            "24": "可回收物/枕头",
            "25": "可回收物/毛绒玩具",
            "26": "可回收物/洗发水瓶",
            "27": "可回收物/玻璃杯",
            "28": "可回收物/皮鞋",
            "29": "可回收物/砧板",
            "30": "可回收物/纸板箱",
            "31": "可回收物/调料瓶",
            "32": "可回收物/酒瓶",
            "33": "可回收物/金属食品罐",
            "34": "可回收物/锅",
            "35": "可回收物/食用油桶",
            "36": "可回收物/饮料瓶",
            "37": "有害垃圾/干电池",
            "38": "有害垃圾/软膏",
            "39": "有害垃圾/过期药物",
            "40": "可回收物/毛巾",
            "41": "可回收物/饮料盒",
            "42": "可回收物/纸袋"
        }
    def _preprocess(self, data):
        preprocessed_data = {}
        for k, v in data.items():
            for file_name, file_content in v.items():
                # img = cv2.imread(file_content, cv2.IMREAD_COLOR)
                # # to rgb
                # img = img[:,:,(2,1,0)]
                img = Image.open(file_content)
                img = img.convert('RGB')
                img = self.transform(np.array(img))
                preprocessed_data[k] = torch.from_numpy(img).permute(2,0,1)
        return preprocessed_data      

    def calc_correct_index(self, predict):
    
        softmax = torch.nn.Softmax(dim = 1)
        out = softmax(predict)
        pred_index = out.max(dim=1)[1]
        return pred_index.item()

    def _inference(self, data):
        img = data[self.input_key_1]
        img = img.unsqueeze(0)
        with torch.no_grad():
            if torch.cuda.is_available(): 
                out = self.model(img.cuda().float())
            else:
                out = self.model(img.float())
        if out is not None:
            pred_index = self.calc_correct_index(out)
            result = {'result': self.label_id_name_dict[str(pred_index)]}
        else:
            result = {'result': 'predict score is None'}
        return result
        
    def _postprocess(self, data):
        return data        

    def inference(self, data):
        data = self._preprocess(data)
        data = self._inference(data)
        data = self._postprocess(data)
        return data


def load_model_for_test(model_path, num_classes=43):

    model_name = os.path.split(model_path)[1]
    model_name = model_name.split('_')[0]
    # 生成网络
    mynet = MyResNet(model_name, num_classes)
    # 加载模型
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    mynet.load_state_dict(checkpoint)

    if torch.cuda.is_available():
        mynet = mynet.cuda()


    if torch.cuda.is_available():
        mynet = torch.nn.DataParallel(mynet).cuda()
    else:
        mynet = torch.nn.DataParallel(mynet)

    mynet.eval()

    return mynet

if __name__ == '__main__':
    m  = garbage_classify_service('resnet101', r'D:\develop\huawei cloud\trash_sort\pytorch-resnet-for-huawei-modelarts\models\resnet101_3.pth')
    data = {'input_img': {'0' : r'D:\develop\huawei cloud\trash_dataset\train_data\img_6.jpg'}}
    result = m.inference(data)
    print(result)
# class garbage_classify_service(PTServingBaseService):
#     def __init__(self, model_name, model_path):
#         # these three parameters are no need to modify
#         self.model_name = model_name
#         self.model_path = model_path
#         self.signature_key = 'predict_images'

#         self.input_size = 224  # the input image size of the model

#         self.input_key_1 = 'input_img'
#         self.output_key_1 = 'output_score'

#         self.transform = BaseTransform(self.input_size, (0, 0, 0))
#         self.net = MyResNet('resnet101', 43)
#         self.net.load_state_dict(os.path.join(self.model_path, self.model_name))
#         self.net.eval()
#         self.label_id_name_dict = \
#         {
#             "0": "其他垃圾/一次性快餐盒",
#             "1": "其他垃圾/污损塑料",
#             "2": "其他垃圾/烟蒂",
#             "3": "其他垃圾/牙签",
#             "4": "其他垃圾/破碎花盆及碟碗",
#             "5": "其他垃圾/竹筷",
#             "6": "厨余垃圾/剩饭剩菜",
#             "7": "厨余垃圾/大骨头",
#             "8": "厨余垃圾/水果果皮",
#             "9": "厨余垃圾/水果果肉",
#             "10": "厨余垃圾/茶叶渣",
#             "11": "厨余垃圾/菜叶菜根",
#             "12": "厨余垃圾/蛋壳",
#             "13": "厨余垃圾/鱼骨",
#             "14": "可回收物/充电宝",
#             "15": "可回收物/包",
#             "16": "可回收物/化妆品瓶",
#             "17": "可回收物/塑料玩具",
#             "18": "可回收物/塑料碗盆",
#             "19": "可回收物/塑料衣架",
#             "20": "可回收物/快递纸袋",
#             "21": "可回收物/插头电线",
#             "22": "可回收物/旧衣服",
#             "23": "可回收物/易拉罐",
#             "24": "可回收物/枕头",
#             "25": "可回收物/毛绒玩具",
#             "26": "可回收物/洗发水瓶",
#             "27": "可回收物/玻璃杯",
#             "28": "可回收物/皮鞋",
#             "29": "可回收物/砧板",
#             "30": "可回收物/纸板箱",
#             "31": "可回收物/调料瓶",
#             "32": "可回收物/酒瓶",
#             "33": "可回收物/金属食品罐",
#             "34": "可回收物/锅",
#             "35": "可回收物/食用油桶",
#             "36": "可回收物/饮料瓶",
#             "37": "有害垃圾/干电池",
#             "38": "有害垃圾/软膏",
#             "39": "有害垃圾/过期药物",
#             "40": "可回收物/毛巾",
#             "41": "可回收物/饮料盒",
#             "42": "可回收物/纸袋"
#         }

    # def _preprocess(self, data):
    #     preprocessed_data = {}
    #     for k, v in data.items():
    #         for file_name, file_content in v.items():
    #             img = cv2.imread(file_content, cv2.IMREAD_COLOR)
    #             # to rgb
    #             img = img[:,:,(2,1,0)]
    #             img = self.transform(img)
    #             preprocessed_data[k] = torch.from_numpy(img).permute(2,0,1)
    #     return preprocessed_data          

#     def calc_correct_index(self, predict):
    
#         softmax = torch.nn.Softmax(dim = 1)
#         out = softmax(predict)
#         pred_index = out.max(dim=1)[1]
#         return pred_index


#     def _inference(self, data):
#         img = data[self.input_key_1]
#         img = img.unsqueeze(0).cpu()
#         out = self.net(img)
#         if out is not None:
#             pred_index = self.calc_correct_index(out)
#             result = {'result': self.label_id_name_dict[str(pred_index)]}
#         else:
#             result = {'result': 'predict score is None'}
#         return result
        
#     def _postprocess(self, data):
#         return data        
